#############################################################################################
# NKPH solution methods
#############################################################################################


using DataStructures
using NLsolve


#########################
# solver struct
struct NKPHSolver
    # Workspace for the NKPH root-finding problem. The struct itself is immutable;
    # every array/dict field is a preallocated buffer that the functions in this
    # file overwrite in place.

    # combine exog/endog parameter dictionary
    # for easy parameter lookup and avoiding repeated computation
    # (update_nkphsolvr! writes the values as [exog params; AR_hat; vec(M)] and
    #  compares new inputs against them to skip redundant updates)
    x_dict::OrderedDict{Symbol,Float64}
    # values of x_dict at which deriv_nkphsolvr! last computed the derivatives
    dx_dict::OrderedDict{Symbol,Float64}

    # mapping from parameters to model objects (model-specific)
    # map_params!(solver): called after AR_hat/M have been updated
    # deriv_params!(solver, param_idx), deriv_AR!(solver, j_idx): called before
    # the generic root derivatives are assembled
    map_params!::Function
    deriv_params!::Function
    deriv_AR!::Function
    #solve_risk_neutral!::Function

    # endogenous parameters (M is J x J, AR_hat has length J)
    M::Matrix{Float64}
    AR_hat::Vector{Float64}

    # state dynamics
    sdyn::StateDynamics
    # model objects; store in (mutable) arrays
    # Sigma = sigma * sigma' and aSigma = a[1] * Sigma (see compute_nkphsolvr_objects!)
    Upsilon::Matrix{Float64}
    sigma::Matrix{Float64}
    Sigma::Matrix{Float64}
    aSigma::Matrix{Float64}

    # scalar coefficients, each wrapped in a 1-element vector so that the value
    # can be changed in place (accessed as a[1], alpha0[1], ...)
    a::Vector{Float64}
    alpha0::Vector{Float64}
    alpha1::Vector{Float64}
    alpha0_til::Vector{Float64}
    alpha1_til::Vector{Float64}
    eta0::Vector{Float64}
    eta1::Vector{Float64}

    Theta0::Vector{Float64}
    Theta1::Vector{Float64}
    Theta0_til::Vector{Float64}
    Theta1_til::Vector{Float64}

    ei_vec::Vector{Float64}
    ed_vec::Vector{Float64}

    # laplace objects (updated inplace)
    Lr::lapl_y
    Lr_til::lapl_y
    LX::lapl_X
    LX_til::lapl_X
    LY::lapl_Y
    LY_til::lapl_Y

    # additional fixed macro parameters (model-specific)
    aux_macro_params::OrderedDict{Symbol,Float64}

    # root terms (all filled by compute_nkphsolvr_objects!)
    # yR_hat       = eta0 * Lr.y + (1 - eta0) * Lr_til.y
    # GammaT_M     = Gamma' - M
    # M_root       = vec(GammaT_M - M_lapl_terms * aSigma)
    # AR_root      = GammaT_M * yR_hat + ei_vec - AR_hat
    yR_hat::Vector{Float64}
    GammaT_M::Matrix{Float64}
    M_lapl_terms::Matrix{Float64}
    M_root::Vector{Float64}
    AR_root::Vector{Float64}
    # combined root vector, stacked as [AR_root; M_root]
    root_vec::Vector{Float64}

    # derivative objects
    # final dim N_params_all (for deriv objects which are used repeatedly);
    # slice/column i holds the derivative with respect to the i-th entry of
    # [exog params; AR_hat; vec(M)]
    # root derivatives (droot_vec stacks [dAR_root; dM_root] column by column)
    dM_root::Matrix{Float64}
    dAR_root::Matrix{Float64}
    droot_vec::Matrix{Float64}

    # additional derivative objects (updated inplace)
    dUpsilon::Array{Float64, 3}
    dGamma::Array{Float64, 3}
    dsigma::Array{Float64, 3}
    dSigma::Array{Float64, 3}
    daSigma::Array{Float64, 3}

    # derivatives of the scalar coefficients: one entry per parameter
    # (length N_params_all)
    dalpha0::Vector{Float64}
    dalpha1::Vector{Float64}
    dalpha0_til::Vector{Float64}
    dalpha1_til::Vector{Float64}
    deta0::Vector{Float64}
    deta1::Vector{Float64}

    dTheta0::Matrix{Float64}
    dTheta1::Matrix{Float64}
    dTheta0_til::Matrix{Float64}
    dTheta1_til::Matrix{Float64}

    dyr::Matrix{Float64}
    dyr_til::Matrix{Float64}
    dX::Array{Float64, 3}
    dX_til::Array{Float64, 3}
    dY::Array{Float64, 3}
    dY_til::Array{Float64, 3}

    # for keeping track of param -> model mapping:
    # param_model_map[i][:name] is true when parameter i affects model object :name
    param_model_map::Vector{OrderedDict{Symbol,Bool}}

    # integer count variables
    # J: dimension of M / AR_hat; N: dimension of Upsilon
    # N_params: number of exogenous params; N_params_all = N_params + J^2 + J
    J::Int
    N::Int
    N_params::Int
    N_params_all::Int
end
# Outer constructor: allocate every buffer of an `NKPHSolver`.
#
# Arguments
# - map_params!, deriv_params!, deriv_AR!: model-specific callbacks, stored as-is.
# - param_names: names of the exogenous parameters; they become the leading keys
#   of `x_dict` (expected to contain `N_params` entries -- not checked here).
# - N: dimension of the (N x N) `Upsilon` matrix handed to `StateDynamics`.
# - J: dimension of the endogenous objects (`M` is J x J, `AR_hat` has length J).
# - sigma: volatility matrix; stored by reference (not copied).
# - a_val, alpha*_val, eta*_val: scalar coefficients, each wrapped in a
#   1-element vector so that it can later be modified in place.
# - Theta0, Theta1, Theta0_til, Theta1_til, ei_vec, ed_vec: stored by reference.
# - N_params: number of exogenous parameters.
#
# Returns a solver whose endogenous objects and derivative buffers are zero and
# whose root vectors are initialised to one.
function NKPHSolver(
        map_params!::Function,
        deriv_params!::Function,
        deriv_AR!::Function,
        param_names::Vector{Symbol},
        N::Int,
        J::Int,
        sigma::Matrix{Float64},
        a_val::Float64,
        alpha0_val::Float64,
        alpha1_val::Float64,
        alpha0_til_val::Float64,
        alpha1_til_val::Float64,
        eta0_val::Float64,
        eta1_val::Float64,
        Theta0::Vector{Float64},
        Theta1::Vector{Float64},
        Theta0_til::Vector{Float64},
        Theta1_til::Vector{Float64},
        ei_vec::Vector{Float64},
        ed_vec::Vector{Float64},
        N_params::Int,
    )
    # total number of unknowns: exogenous params + AR_hat (J) + vec(M) (J^2)
    N_params_all = N_params + J^2 + J

    # create parameter dicts
    # The key order MUST match the value layout used everywhere else in this
    # file, i.e. [exog params; AR_hat; vec(M)] (see update_nkphsolvr! and
    # _m_idx_to_deriv_idx / _AR_idx_to_deriv_idx). Otherwise a lookup such as
    # x_dict[:AR_1] would return an entry of M.
    keys_AR = [Symbol("AR_$(j)") for j in 1:J]
    # keys_M[j1, j2] names M[j1, j2]; vec() flattens it column-major like vec(M)
    keys_M = [Symbol("M_$(j1)_$(j2)") for j1 in 1:J, j2 in 1:J]
    x_dict = OrderedDict{Symbol, Float64}()
    for key in [param_names; keys_AR; vec(keys_M)]
        x_dict[key] = 0.
    end
    dx_dict = deepcopy(x_dict)

    # initialize matrices
    M = zeros(J, J)
    AR_hat = zeros(J)
    Upsilon = zeros(N, N)
    sdyn = StateDynamics(Upsilon, J)
    Sigma = sigma * sigma'
    aSigma = a_val .* Sigma

    # scalars wrapped in 1-element (mutable) arrays
    a = [a_val]
    alpha0 = [alpha0_val]
    alpha1 = [alpha1_val]
    alpha0_til = [alpha0_til_val]
    alpha1_til = [alpha1_til_val]
    eta0 = [eta0_val]
    eta1 = [eta1_val]

    # initialize lapl objects; the *_til variants use -ed_vec as last argument
    Lr = lapl_y(M, eta1, ei_vec, zeros(J))
    Lr_til = lapl_y(M, eta1, ei_vec, -ed_vec)
    LX = lapl_X(M, alpha1, ei_vec, zeros(J))
    LX_til = lapl_X(M, alpha1_til, ei_vec, -ed_vec)
    LY = lapl_Y(M, Theta1, ei_vec, zeros(J))
    LY_til = lapl_Y(M, Theta1_til, ei_vec, -ed_vec)

    # initialize auxiliary macro dict (filled by model-specific code)
    aux_macro_params = OrderedDict{Symbol, Float64}()

    # root terms
    yR_hat = zeros(J)
    GammaT_M = zeros(J, J)
    M_lapl_terms = zeros(J, J)
    M_root = ones(J^2)
    AR_root = ones(J)
    root_vec = ones(J^2 + J)

    # root derivatives: one column per entry of [exog params; AR_hat; vec(M)]
    dM_root = zeros(J^2, N_params_all)
    dAR_root = zeros(J, N_params_all)
    droot_vec = zeros(J^2 + J, N_params_all)

    # additional deriv objects: last dimension indexes the parameter
    dUpsilon = zeros(N, N, N_params_all)
    dGamma = zeros(J, J, N_params_all)
    dsigma = zeros(J, J, N_params_all)
    dSigma = zeros(J, J, N_params_all)
    daSigma = zeros(J, J, N_params_all)

    # derivatives of the scalar coefficients (one entry per parameter)
    dalpha0 = zeros(N_params_all)
    dalpha1 = zeros(N_params_all)
    dalpha0_til = zeros(N_params_all)
    dalpha1_til = zeros(N_params_all)
    deta0 = zeros(N_params_all)
    deta1 = zeros(N_params_all)

    dTheta0 = zeros(J, N_params_all)
    dTheta1 = zeros(J, N_params_all)
    dTheta0_til = zeros(J, N_params_all)
    dTheta1_til = zeros(J, N_params_all)

    dyr = zeros(J, N_params_all)
    dyr_til = zeros(J, N_params_all)
    dX = zeros(J, J, N_params_all)
    dX_til = zeros(J, J, N_params_all)
    dY = zeros(J, J, N_params_all)
    dY_til = zeros(J, J, N_params_all)

    # per-parameter flags: which model objects does the parameter affect?
    # (all false initially; model-specific code is expected to switch them on)
    model_objects = (
        :Upsilon, :sigma,
        :alpha0, :alpha1, :alpha0_til, :alpha1_til,
        :eta0, :eta1,
        :Theta0, :Theta1, :Theta0_til, :Theta1_til,
    )
    param_model_map = Vector{OrderedDict{Symbol,Bool}}(undef, N_params_all)
    for i in 1:N_params_all
        pmap = OrderedDict{Symbol,Bool}()
        for obj in model_objects
            pmap[obj] = false
        end
        param_model_map[i] = pmap
    end

    # positional call: argument order must follow the struct's field order
    return NKPHSolver(
        x_dict,
        dx_dict,
        map_params!,
        deriv_params!,
        deriv_AR!,
        M,
        AR_hat,
        sdyn,
        Upsilon,
        sigma,
        Sigma,
        aSigma,
        a,
        alpha0,
        alpha1,
        alpha0_til,
        alpha1_til,
        eta0,
        eta1,
        Theta0,
        Theta1,
        Theta0_til,
        Theta1_til,
        ei_vec,
        ed_vec,
        Lr,
        Lr_til,
        LX,
        LX_til,
        LY,
        LY_til,
        aux_macro_params,
        yR_hat,
        GammaT_M,
        M_lapl_terms,
        M_root,
        AR_root,
        root_vec,
        # derivs
        dM_root,
        dAR_root,
        droot_vec,
        dUpsilon,
        dGamma,
        dsigma,
        dSigma,
        daSigma,
        dalpha0,
        dalpha1,
        dalpha0_til,
        dalpha1_til,
        deta0,
        deta1,
        dTheta0,
        dTheta1,
        dTheta0_til,
        dTheta1_til,
        dyr,
        dyr_til,
        dX,
        dX_til,
        dY,
        dY_til,
        param_model_map,
        J,
        N,
        N_params,
        N_params_all,
    )
end

####################################
# updating model:
# 1. update AR_hat/M (generic)
# 2. update dynamics/habitat params (model-specific)
# 3. compute additional habitat/dynamics/root objects (generic)

function compute_nkphsolvr_objects!(nkphsolvr::NKPHSolver)::Nothing
    # Recompute (in place) every derived solver object and the root vector from
    # the current parameters, M and AR_hat. Call after the parameters changed.
    s = nkphsolvr
    Gamma = s.sdyn.Gamma

    # state dynamics first: refreshes Gamma (inplace)
    update_state_dynamics!(s.sdyn)

    # covariance: Sigma = sigma * sigma', aSigma = a * Sigma
    s.Sigma .= s.sigma * s.sigma'
    s.aSigma .= s.a[1] .* s.Sigma

    # refresh the laplace objects (X, Y, then the effective borrowing rate y)
    update_lapl_X!(s.LX)
    update_lapl_X!(s.LX_til)
    update_lapl_Y!(s.LY)
    update_lapl_Y!(s.LY_til)
    update_lapl_y!(s.Lr)
    update_lapl_y!(s.Lr_til)

    # eta0-weighted combination of the two y vectors
    w = s.eta0[1]
    s.yR_hat .= w .* s.Lr.y .+ (1 - w) .* s.Lr_til.y
    s.GammaT_M .= Gamma' .- s.M

    # laplace terms entering the M root; Diagonal keeps the scaling cheap
    terms = Diagonal(s.Theta0) * s.LY.Y' - s.alpha0[1] .* s.LX.X
    terms_til = Diagonal(s.Theta0_til) * s.LY_til.Y' - s.alpha0_til[1] .* s.LX_til.X
    s.M_lapl_terms .= terms .+ terms_til

    # the two root blocks ...
    s.M_root .= vec(s.GammaT_M - s.M_lapl_terms * s.aSigma)
    s.AR_root .= s.GammaT_M * s.yR_hat .+ s.ei_vec .- s.AR_hat

    # ... stacked as [AR_root; M_root]
    s.root_vec .= vcat(s.AR_root, s.M_root)
    return nothing
end
# wrapper
function update_nkphsolvr!(x_arr::Vector{Float64}, nkphsolvr::NKPHSolver)::Nothing
    # Load the full vector x_arr = [exog params; AR_hat; vec(M)] into the solver
    # and recompute all root objects. Does nothing when x_arr equals the values
    # already stored in x_dict.
    stored = nkphsolvr.x_dict.vals
    x_arr == stored && return nothing

    stored[:] = x_arr
    n_exog, dim = nkphsolvr.N_params, nkphsolvr.J

    # endogenous params: AR_hat follows the exogenous block, vec(M) is the tail
    nkphsolvr.AR_hat[:] = x_arr[n_exog+1:n_exog+dim]
    nkphsolvr.M[:, :] = reshape(x_arr[end-dim^2+1:end], dim, dim)

    # exogenous params (model-specific mapping), then the generic objects
    nkphsolvr.map_params!(nkphsolvr)
    compute_nkphsolvr_objects!(nkphsolvr)
    return nothing
end
# only exog params
function update_nkphsolvr_params!(p_arr::Vector{Float64}, nkphsolvr::NKPHSolver)::Nothing
    # Store the exogenous parameter values p_arr in the leading N_params slots
    # of x_dict. Nothing is recomputed here: neither map_params! nor the root
    # objects are touched.
    n_exog = nkphsolvr.N_params
    stored = nkphsolvr.x_dict.vals
    stored[1:n_exog] = p_arr
    return nothing
end

#########################
# deriv computations
# wrappers to map derivs correctly
function _m_idx_to_deriv_idx(nkphsolvr::NKPHSolver, j_idx::Int, k_idx::Int)::Int
    # Derivative column of M[j_idx, k_idx]; columns are laid out as
    # [exog params | AR_hat | vec(M)], with vec(M) in column-major order.
    dim = nkphsolvr.J
    pos_in_vecM = j_idx + dim * (k_idx - 1)
    offset = nkphsolvr.N_params + dim
    return offset + pos_in_vecM
end
function _AR_idx_to_deriv_idx(nkphsolvr::NKPHSolver, j_idx::Int)::Int
    # Derivative column of AR_hat[j_idx]: the AR block directly follows the
    # N_params exogenous-parameter columns.
    offset = nkphsolvr.N_params
    return offset + j_idx
end

function deriv_nkphsolvr_objects_wrtM!(nkphsolvr::NKPHSolver, j_idx::Int, k_idx::Int)::Nothing
    # Derivative of the root objects with respect to M[j_idx, k_idx] (generic).
    # Results are written into the derivative column belonging to that entry.
    col = _m_idx_to_deriv_idx(nkphsolvr, j_idx, k_idx)

    # in-place targets for this derivative column
    dX = view(nkphsolvr.dX, :, :, col)
    dX_til = view(nkphsolvr.dX_til, :, :, col)
    dY = view(nkphsolvr.dY, :, :, col)
    dY_til = view(nkphsolvr.dY_til, :, :, col)
    dyr = view(nkphsolvr.dyr, :, col)
    dyr_til = view(nkphsolvr.dyr_til, :, col)
    dM_root = view(nkphsolvr.dM_root, :, col)
    dAR_root = view(nkphsolvr.dAR_root, :, col)

    # derivatives of the laplace objects (X, Y, y)
    deriv_lapl_X_wrtM!(dX, nkphsolvr.LX, j_idx, k_idx)
    deriv_lapl_X_wrtM!(dX_til, nkphsolvr.LX_til, j_idx, k_idx)
    deriv_lapl_Y_wrtM!(dY, nkphsolvr.LY, j_idx, k_idx)
    deriv_lapl_Y_wrtM!(dY_til, nkphsolvr.LY_til, j_idx, k_idx)
    deriv_lapl_y_wrtM!(dyr, nkphsolvr.Lr, j_idx, k_idx)
    deriv_lapl_y_wrtM!(dyr_til, nkphsolvr.Lr_til, j_idx, k_idx)

    # M root: same structure as M_lapl_terms in compute_nkphsolvr_objects!,
    # with the laplace objects replaced by their derivatives
    part = Diagonal(nkphsolvr.Theta0) * dY' - nkphsolvr.alpha0[1] .* dX
    part_til = Diagonal(nkphsolvr.Theta0_til) * dY_til' - nkphsolvr.alpha0_til[1] .* dX_til
    dM_lapl_terms = part + part_til
    dM_mat = -dM_lapl_terms * nkphsolvr.aSigma
    # direct effect of M[j, k] through GammaT_M = Gamma' - M
    dM_mat[j_idx, k_idx] -= 1.
    dM_root[:] = vec(dM_mat)

    # AR root: GammaT_M times the derivative of the eta0-weighted y vector ...
    w = nkphsolvr.eta0[1]
    dyR_hat = w .* dyr + (1 - w) .* dyr_til
    dAR_root[:] = nkphsolvr.GammaT_M * dyR_hat
    # ... plus the direct effect of M[j, k] on row j of GammaT_M * yR_hat
    dAR_root[j_idx] -= nkphsolvr.yR_hat[k_idx]

    # stacked as [dAR_root; dM_root]
    nkphsolvr.droot_vec[:, col] = vcat(dAR_root, dM_root)
    return nothing
end

function deriv_nkphsolvr_objects_wrtAR!(nkphsolvr::NKPHSolver, j_idx)::Nothing
    # Derivative of the root objects with respect to AR_hat[j_idx].
    # The model-specific callback runs first and sets the deriv model objects
    # (dUpsilon is read below).
    nkphsolvr.deriv_AR!(nkphsolvr, j_idx)

    # in-place targets for this derivative column
    col = _AR_idx_to_deriv_idx(nkphsolvr, j_idx)
    dUpsilon = view(nkphsolvr.dUpsilon, :, :, col)
    dGamma = view(nkphsolvr.dGamma, :, :, col)
    dM_root = view(nkphsolvr.dM_root, :, col)
    dAR_root = view(nkphsolvr.dAR_root, :, col)

    # propagate dUpsilon to dGamma through the state dynamics
    deriv_state_dynamics_wrtUpsilon!(dGamma, nkphsolvr.sdyn, dUpsilon)

    # both roots depend on AR_hat through Gamma' ...
    dGammaT = dGamma'
    dM_root[:] = vec(dGammaT)
    dAR_root[:] = dGammaT * nkphsolvr.yR_hat
    # ... and AR_root additionally contains -AR_hat itself
    dAR_root[j_idx] -= 1.

    # stacked as [dAR_root; dM_root]
    nkphsolvr.droot_vec[:, col] = vcat(dAR_root, dM_root)
    return nothing
end

function deriv_nkphsolvr_objects_wrtparam!(nkphsolvr::NKPHSolver, param_idx)::Nothing
    # Derivative of the root objects with respect to exogenous parameter
    # `param_idx`; the result is written to column `param_idx` of dM_root,
    # dAR_root and droot_vec.
    #
    # The model-specific callback runs first; it is expected to fill the
    # per-parameter derivative inputs read below (dsigma, dUpsilon, dalpha0,
    # dalpha0_til, deta0) -- presumably together with the param_model_map flags.
    nkphsolvr.deriv_params!(nkphsolvr, param_idx)

    # pull out objects
    sdyn = nkphsolvr.sdyn
    sigma, aSigma, a = nkphsolvr.sigma, nkphsolvr.aSigma, nkphsolvr.a
    LX, LX_til = nkphsolvr.LX, nkphsolvr.LX_til
    Lr, Lr_til = nkphsolvr.Lr, nkphsolvr.Lr_til
    yR_hat, GammaT_M, M_lapl_terms = nkphsolvr.yR_hat, nkphsolvr.GammaT_M, nkphsolvr.M_lapl_terms

    # deriv objects (updated inplace using the parameter's own column)
    pmap = nkphsolvr.param_model_map[param_idx]
    dsigma = @view nkphsolvr.dsigma[:, :, param_idx]
    dSigma = @view nkphsolvr.dSigma[:, :, param_idx]
    daSigma = @view nkphsolvr.daSigma[:, :, param_idx]
    dUpsilon = @view nkphsolvr.dUpsilon[:, :, param_idx]
    dGamma = @view nkphsolvr.dGamma[:, :, param_idx]
    dalpha0 = @view nkphsolvr.dalpha0[param_idx]
    dalpha0_til = @view nkphsolvr.dalpha0_til[param_idx]
    deta0 = @view nkphsolvr.deta0[param_idx]
    dM_root = @view nkphsolvr.dM_root[:, param_idx]
    dAR_root = @view nkphsolvr.dAR_root[:, param_idx]

    # contributions are accumulated, so start from zero
    dM_root .= 0.
    dAR_root .= 0.

    # check which model objects the parameter affects and add their contribution
    # TODO not all parameters supported yet
    # (alpha1, alpha1_til, eta1 and the Theta flags are currently ignored)
    if pmap[:sigma]
        # Sigma = sigma * sigma'  =>  dSigma = dsigma * sigma' + sigma * dsigma'
        # (product rule; note the transpose on sigma in the first term)
        dSigma[:, :] = dsigma * sigma' + sigma * dsigma'
        daSigma[:, :] = a[1] .* dSigma
        dM_root .+= vec( -M_lapl_terms * daSigma )
        # note: AR_root unchanged
    end
    if pmap[:Upsilon]
        # compute state dynamics Gamma deriv
        deriv_state_dynamics_wrtUpsilon!(dGamma, sdyn, dUpsilon)
        dM_root .+= vec( dGamma' )
        dAR_root .+= dGamma' * yR_hat
    end
    if pmap[:alpha0]
        # alpha0 enters M_lapl_terms as -alpha0 * LX.X, and M_root as
        # -M_lapl_terms * aSigma, hence the positive sign
        dM_root .+= vec( dalpha0[1] .* LX.X * aSigma )
        # note: AR_root unchanged
    end
    if pmap[:alpha0_til]
        # same as alpha0, for the "til" block
        dM_root .+= vec( dalpha0_til[1] .* LX_til.X * aSigma )
        # note: AR_root unchanged
    end
    if pmap[:eta0]
        # yR_hat = eta0 * Lr.y + (1 - eta0) * Lr_til.y
        dAR_root .+= GammaT_M * (Lr.y - Lr_til.y) .* deta0[1]
        # note: M_root unchanged
    end

    # combined, stacked as [dAR_root; dM_root]
    nkphsolvr.droot_vec[:, param_idx] = [dAR_root; dM_root]
    return nothing
end

# wrapper
function deriv_nkphsolvr!(nkphsolvr::NKPHSolver)::Nothing
    # Fill every column of the root derivative (exog params, AR_hat, M).
    # Skipped when the derivatives were already computed at the current x_dict
    # values (tracked through dx_dict).
    current = nkphsolvr.x_dict.vals
    last_eval = nkphsolvr.dx_dict.vals
    current == last_eval && return nothing
    last_eval[:] = current

    # derivs wrt params (model-specific)
    for p in 1:nkphsolvr.N_params
        deriv_nkphsolvr_objects_wrtparam!(nkphsolvr, p)
    end

    dim = nkphsolvr.J
    # derivs wrt AR_hat (model-specific)
    for j in 1:dim
        deriv_nkphsolvr_objects_wrtAR!(nkphsolvr, j)
    end
    # derivs wrt M; column index outermost so the order matches vec(M)
    for k in 1:dim, j in 1:dim
        deriv_nkphsolvr_objects_wrtM!(nkphsolvr, j, k)
    end
    return nothing
end



####################
# continuation solver

function _get_risk_neutral_solution!(nkphsolvr::NKPHSolver)::Vector{Float64}
    # Set the solver to the risk-neutral endogenous solution (given the
    # exogenous params currently stored) and return it as [AR_hat; vec(M)].
    dim, n_exog = nkphsolvr.J, nkphsolvr.N_params
    stored = nkphsolvr.x_dict.vals

    # implied solution AR_hat = ei_vec; mirror it into x_dict before the
    # model-specific mapping runs
    nkphsolvr.AR_hat[:] = nkphsolvr.ei_vec
    stored[n_exog+1:n_exog+dim] = nkphsolvr.AR_hat
    nkphsolvr.map_params!(nkphsolvr)

    # solve for the Gamma dynamics matrix; implied solution M = Gamma^T
    update_state_dynamics!(nkphsolvr.sdyn)
    nkphsolvr.M .= nkphsolvr.sdyn.Gamma'
    stored[end-dim^2+1:end] = vec(nkphsolvr.M)

    # vcat allocates, so the result does not alias the solver's buffers
    return vcat(nkphsolvr.AR_hat, vec(nkphsolvr.M))
end


# factory for creating root function/jacobian for continuation solver
function _get_fj_funcs(nkphsolvr::NKPHSolver)::Function
    # Build the combined residual/Jacobian callback for NLsolve (to be wrapped
    # with only_fj!). The unknown is z_vec = [AR_hat; vec(M)]; the exogenous
    # params stay at whatever the solver currently holds.
    #
    # fj!(FF, JJ, z_vec):
    # - FF: residual buffer, or `nothing` when the residual is not requested
    # - JJ: Jacobian buffer ((J + J^2) x (J + J^2)), or `nothing`
    # The solver state is always updated to z_vec, whatever is requested.
    function fj!(
            FF::Union{Vector{Float64}, Nothing},
            JJ::Union{Matrix{Float64}, Nothing},
            z_vec::Vector{Float64}
        )
        # compute root, function of M/AR_hat only
        J, N_params = nkphsolvr.J, nkphsolvr.N_params
        # update endogenous params
        AR_vec = z_vec[1:J]
        m_vec = z_vec[end-J^2+1:end]
        nkphsolvr.AR_hat[:] = AR_vec
        nkphsolvr.M[:, :] = reshape(m_vec, J, J)
        # Mirror the new endogenous values into x_dict, as update_nkphsolvr!
        # and _get_risk_neutral_solution! do. Without this, x_dict no longer
        # describes the solver state: the cache check in update_nkphsolvr!
        # could wrongly skip a recomputation after a solve, and a map_params!
        # that looks values up in x_dict would see stale AR_hat/M entries.
        x_vals = nkphsolvr.x_dict.vals
        x_vals[N_params+1:N_params+J] = AR_vec
        x_vals[end-J^2+1:end] = m_vec
        # note: still need to call map_params due to AR_hat
        nkphsolvr.map_params!(nkphsolvr)
        compute_nkphsolvr_objects!(nkphsolvr)
        # assign root/jacobian objects
        if FF !== nothing
            FF[:] = nkphsolvr.root_vec
        end
        if JJ !== nothing
            # derivs wrt AR_hat (model-specific)
            for j_idx in 1:J
                deriv_nkphsolvr_objects_wrtAR!(nkphsolvr, j_idx)
            end
            # derivs wrt M (loop order corresponds to vec(M) order)
            for k_idx in 1:J, j_idx in 1:J
                deriv_nkphsolvr_objects_wrtM!(nkphsolvr, j_idx, k_idx)
            end
            # keep only the endogenous columns (drop the exogenous-param block)
            JJ[:, :] = @view nkphsolvr.droot_vec[:, N_params+1:end]
        end
        return nothing
    end
    return fj!
end

function solve_nkphsolvr_fast!(z0_vec::Vector{Float64}, nkphsolvr::NKPHSolver;
        kwargs...)::NLsolve.SolverResults
    # Solve for the endogenous unknowns only, starting from z0_vec, with the
    # exogenous params held fixed. All keyword arguments go to nlsolve.
    objective = only_fj!(_get_fj_funcs(nkphsolvr))
    return nlsolve(objective, z0_vec; kwargs...)
end

function solve_nkphsolvr_continuation!(N_a::Int, nkphsolvr::NKPHSolver;
        kwargs...)::Tuple{Bool, Vector{Float64}}
    # Continuation solver (endogenous unknowns only): step `a` from 0 to its
    # current value on a grid of N_a points, warm-starting every solve from
    # the previous solution.
    #
    # Returns (solved_model, z_soln) with z_soln = [AR_hat; vec(M)]. When a step
    # fails to converge, solved_model is false and z_soln is the solution of
    # the last converged step (the risk-neutral one if the first step failed).
    # Keyword arguments are forwarded to nlsolve.
    #
    # nkphsolvr.a[1] is always restored to its initial value on exit, also when
    # a solve or a model callback throws.
    a_final = nkphsolvr.a[1]
    # range() rejects length=1 when the endpoints differ; a single grid point
    # means one direct solve at the target value
    a_arr = N_a == 1 ? [a_final] : collect(range(0.0, a_final, length=N_a))

    # a=0 solution: AR_hat = ei_vec and M = Gamma^T
    # must solve for Gamma from Upsilon using model-specific mapping
    z_soln = _get_risk_neutral_solution!(nkphsolvr)
    solved_model = true
    try
        for a_val in a_arr
            # update a value
            nkphsolvr.a[1] = a_val
            nlsolve_res = solve_nkphsolvr_fast!(z_soln, nkphsolvr; kwargs...)
            if !converged(nlsolve_res)
                solved_model = false
                break
            end
            z_soln[:] = nlsolve_res.zero
        end
    finally
        # never leave the solver at an intermediate continuation value
        nkphsolvr.a[1] = a_final
    end
    return solved_model, z_soln
end
